{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "name": "BERT-Torch",
      "provenance": [],
      "collapsed_sections": []
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "metadata": {
        "id": "j2KIji9Ts0Wn",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "'''\n",
        "  code by Tae Hwan Jung(Jeff Jung) @graykode, modified by wmathor\n",
        "  Reference : https://github.com/jadore801120/attention-is-all-you-need-pytorch\n",
        "         https://github.com/JayParks/transformer, https://github.com/dhlee347/pytorchic-bert\n",
        "'''\n",
        "import re\n",
        "import math\n",
        "import torch\n",
        "import numpy as np\n",
        "from random import *\n",
        "import torch.nn as nn\n",
        "import torch.optim as optim\n",
        "import torch.utils.data as Data\n",
        "\n",
        "# Toy corpus: a short dialogue, one utterance per line (the # R / # J tags mark the speaker).\n",
        "text = (\n",
        "    'Hello, how are you? I am Romeo.\\n' # R\n",
        "    'Hello, Romeo My name is Juliet. Nice to meet you.\\n' # J\n",
        "    'Nice meet you too. How are you today?\\n' # R\n",
        "    'Great. My baseball team won the competition.\\n' # J\n",
        "    'Oh Congratulations, Juliet\\n' # R\n",
        "    'Thank you Romeo\\n' # J\n",
        "    'Where are you going today?\\n' # R\n",
        "    'I am going shopping. What about you?\\n' # J\n",
        "    'I am going to visit my grandmother. she is not very well' # R\n",
        ")\n",
        "# Lower-case the text, drop '.', ',', '!', '?' and '-', then split into one sentence per line.\n",
        "sentences = re.sub(\"[.,!?\\\\-]\", '', text.lower()).split('\\n')\n",
        "# Sort the unique words so that token ids are the same on every run; iterating a bare set of\n",
        "# strings depends on hash randomisation, which made the ids differ between kernels.\n",
        "word_list = sorted(set(\" \".join(sentences).split())) # ['about', 'am', 'are', ...]\n",
        "# Ids 0-3 are reserved for the special tokens; corpus words start at 4.\n",
        "word2idx = {'[PAD]' : 0, '[CLS]' : 1, '[SEP]' : 2, '[MASK]' : 3}\n",
        "for i, w in enumerate(word_list, start=4):\n",
        "    word2idx[w] = i\n",
        "# Invert the mapping explicitly instead of relying on dict insertion order.\n",
        "idx2word = {i: w for w, i in word2idx.items()}\n",
        "vocab_size = len(word2idx)\n",
        "\n",
        "# Every sentence as a list of token ids.\n",
        "token_list = [[word2idx[w] for w in sentence.split()] for sentence in sentences]"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "-AuXO3rQJIUj",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# BERT hyper-parameters\n",
        "maxlen = 30          # every sample is zero-padded to this many tokens\n",
        "batch_size = 6       # samples per batch (half IsNext, half NotNext)\n",
        "max_pred = 5         # upper bound on masked tokens per sample\n",
        "n_layers = 6         # number of encoder layers\n",
        "n_heads = 12         # attention heads per layer\n",
        "d_model = 768        # embedding / hidden size\n",
        "d_ff = 4 * d_model   # FeedForward inner dimension\n",
        "d_k = d_v = 64       # per-head dimension of K(=Q) and V\n",
        "n_segments = 2       # number of segment (token type) ids"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bOXYOwFAJH93",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Sample the same number of IsNext and NotNext pairs (the batch is small)\n",
        "def make_data():\n",
        "    \"\"\"Build one batch of BERT pre-training samples from the toy corpus.\n",
        "\n",
        "    Sentence pairs are drawn at random until the batch holds ``batch_size // 2``\n",
        "    IsNext pairs (sentence B directly follows sentence A in ``sentences``) and\n",
        "    ``batch_size - batch_size // 2`` NotNext pairs.\n",
        "\n",
        "    Returns:\n",
        "        list: ``batch_size`` samples, each a list\n",
        "        ``[input_ids, segment_ids, masked_tokens, masked_pos, is_next]`` where\n",
        "        ``input_ids`` / ``segment_ids`` are zero-padded to ``maxlen`` and\n",
        "        ``masked_tokens`` / ``masked_pos`` are zero-padded to ``max_pred``.\n",
        "    \"\"\"\n",
        "    # Integer quotas: comparing the counters with the float batch_size/2 could never\n",
        "    # succeed for an odd batch_size, so the sampling loop never terminated.\n",
        "    n_positive = batch_size // 2\n",
        "    n_negative = batch_size - n_positive\n",
        "    batch = []\n",
        "    positive = negative = 0\n",
        "    while positive < n_positive or negative < n_negative:\n",
        "        # sample two random sentence indices\n",
        "        tokens_a_index, tokens_b_index = randrange(len(sentences)), randrange(len(sentences))\n",
        "        is_next = tokens_a_index + 1 == tokens_b_index\n",
        "        # skip the pair early when its class already has enough samples\n",
        "        quota_full = positive >= n_positive if is_next else negative >= n_negative\n",
        "        if quota_full:\n",
        "            continue\n",
        "\n",
        "        tokens_a, tokens_b = token_list[tokens_a_index], token_list[tokens_b_index]\n",
        "        input_ids = [word2idx['[CLS]']] + tokens_a + [word2idx['[SEP]']] + tokens_b + [word2idx['[SEP]']]\n",
        "        segment_ids = [0] * (1 + len(tokens_a) + 1) + [1] * (len(tokens_b) + 1)\n",
        "\n",
        "        # MASK LM: 15 % of the tokens, at least 1 and at most max_pred\n",
        "        n_pred = min(max_pred, max(1, int(len(input_ids) * 0.15)))\n",
        "        # candidate positions exclude [CLS] and [SEP]\n",
        "        cand_masked_pos = [i for i, token in enumerate(input_ids)\n",
        "                           if token != word2idx['[CLS]'] and token != word2idx['[SEP]']]\n",
        "        shuffle(cand_masked_pos)\n",
        "        masked_tokens, masked_pos = [], []\n",
        "        for pos in cand_masked_pos[:n_pred]:\n",
        "            masked_pos.append(pos)\n",
        "            masked_tokens.append(input_ids[pos])\n",
        "            # A single draw decides the 80/10/10 split. Drawing a second number in the\n",
        "            # elif made the random-replacement branch fire only about 2 % of the time.\n",
        "            p = random()\n",
        "            if p < 0.8:  # 80 %: replace with [MASK]\n",
        "                input_ids[pos] = word2idx['[MASK]']\n",
        "            elif p < 0.9:  # 10 %: replace with a random non-special token (ids 0-3 are special)\n",
        "                input_ids[pos] = randint(4, vocab_size - 1)\n",
        "            # remaining 10 %: keep the original token\n",
        "\n",
        "        # Zero-pad the sequence inputs up to maxlen\n",
        "        n_pad = maxlen - len(input_ids)\n",
        "        input_ids.extend([0] * n_pad)\n",
        "        segment_ids.extend([0] * n_pad)\n",
        "\n",
        "        # Zero-pad the prediction targets up to max_pred\n",
        "        n_pad = max_pred - len(masked_tokens)\n",
        "        masked_tokens.extend([0] * n_pad)\n",
        "        masked_pos.extend([0] * n_pad)\n",
        "\n",
        "        batch.append([input_ids, segment_ids, masked_tokens, masked_pos, is_next])\n",
        "        if is_next:\n",
        "            positive += 1\n",
        "        else:\n",
        "            negative += 1\n",
        "    return batch\n",
        "# Preprocessing finished"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "bqxycRhzia7r",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "# Build one batch and turn each of its five fields into a LongTensor.\n",
        "batch = make_data()\n",
        "input_ids, segment_ids, masked_tokens, masked_pos, isNext = (\n",
        "    torch.LongTensor(field) for field in zip(*batch)\n",
        ")\n",
        "\n",
        "class MyDataSet(Data.Dataset):\n",
        "  \"\"\"Dataset over five aligned sequences; item ``idx`` is the tuple of their ``idx``-th entries.\"\"\"\n",
        "\n",
        "  def __init__(self, input_ids, segment_ids, masked_tokens, masked_pos, isNext):\n",
        "    self.input_ids, self.segment_ids = input_ids, segment_ids\n",
        "    self.masked_tokens, self.masked_pos = masked_tokens, masked_pos\n",
        "    self.isNext = isNext\n",
        "\n",
        "  def __len__(self):\n",
        "    # the number of samples is taken from input_ids\n",
        "    return len(self.input_ids)\n",
        "\n",
        "  def __getitem__(self, idx):\n",
        "    fields = (self.input_ids, self.segment_ids, self.masked_tokens, self.masked_pos, self.isNext)\n",
        "    return tuple(field[idx] for field in fields)\n",
        "\n",
        "# Shuffled mini-batches over the generated samples\n",
        "loader = Data.DataLoader(\n",
        "    MyDataSet(input_ids, segment_ids, masked_tokens, masked_pos, isNext),\n",
        "    batch_size=batch_size,\n",
        "    shuffle=True,\n",
        ")"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "6inMS744xRwh",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        "def get_attn_pad_mask(seq_q, seq_k):\n",
        "    \"\"\"Build the attention mask that hides [PAD] (id 0) key positions.\n",
        "\n",
        "    Args:\n",
        "        seq_q: token ids of the queries, shape [batch_size, len_q].\n",
        "        seq_k: token ids of the keys, shape [batch_size, len_k].\n",
        "\n",
        "    Returns:\n",
        "        Tensor of shape [batch_size, len_q, len_k] that is True wherever the key token is 0.\n",
        "    \"\"\"\n",
        "    n_batch, len_q = seq_q.size()\n",
        "    len_k = seq_k.size(1)\n",
        "    # Padding is a property of the keys being attended to, so the mask is derived from seq_k.\n",
        "    # Deriving it from seq_q (and ignoring seq_k) is only correct when both are the same sequence.\n",
        "    pad_attn_mask = seq_k.eq(0).unsqueeze(1)  # [batch_size, 1, len_k]\n",
        "    return pad_attn_mask.expand(n_batch, len_q, len_k)  # [batch_size, len_q, len_k]\n",
        "\n",
        "def gelu(x):\n",
        "    \"\"\"\n",
        "      Exact (erf-based) GELU activation: x * 0.5 * (1 + erf(x / sqrt(2))).\n",
        "      The tanh approximation used by OpenAI GPT gives slightly different results:\n",
        "      0.5 * x * (1 + torch.tanh(math.sqrt(2 / math.pi) * (x + 0.044715 * torch.pow(x, 3))))\n",
        "      See https://arxiv.org/abs/1606.08415\n",
        "    \"\"\"\n",
        "    half_x = x * 0.5\n",
        "    erf_term = torch.erf(x / math.sqrt(2.0))\n",
        "    return half_x * (1.0 + erf_term)\n",
        "\n",
        "class Embedding(nn.Module):\n",
        "    \"\"\"Sum of token, position and segment embeddings, followed by LayerNorm.\"\"\"\n",
        "    def __init__(self):\n",
        "        super(Embedding, self).__init__()\n",
        "        self.tok_embed = nn.Embedding(vocab_size, d_model)  # token embedding\n",
        "        self.pos_embed = nn.Embedding(maxlen, d_model)  # position embedding\n",
        "        self.seg_embed = nn.Embedding(n_segments, d_model)  # segment(token type) embedding\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "\n",
        "    def forward(self, x, seg):\n",
        "        \"\"\"Embed token ids ``x`` and segment ids ``seg`` (both [batch_size, seq_len]).\n",
        "\n",
        "        Returns the normalised embeddings, [batch_size, seq_len, d_model].\n",
        "        \"\"\"\n",
        "        seq_len = x.size(1)\n",
        "        # Create the position ids on the input's device; a CPU-only arange fails with a\n",
        "        # device mismatch as soon as the model and its inputs live on the GPU.\n",
        "        pos = torch.arange(seq_len, dtype=torch.long, device=x.device)\n",
        "        pos = pos.unsqueeze(0).expand_as(x)  # [seq_len] -> [batch_size, seq_len]\n",
        "        embedding = self.tok_embed(x) + self.pos_embed(pos) + self.seg_embed(seg)\n",
        "        return self.norm(embedding)\n",
        "\n",
        "class ScaledDotProductAttention(nn.Module):\n",
        "    \"\"\"softmax(Q K^T / sqrt(d_k)) V, with masked positions excluded from the softmax.\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "\n",
        "    def forward(self, Q, K, V, attn_mask):\n",
        "        scale = math.sqrt(d_k)\n",
        "        logits = torch.matmul(Q, K.transpose(-1, -2)) / scale  # [batch_size, n_heads, seq_len, seq_len]\n",
        "        # positions where attn_mask is True get a large negative score, i.e. ~0 weight after softmax\n",
        "        logits.masked_fill_(attn_mask, -1e9)\n",
        "        weights = nn.Softmax(dim=-1)(logits)\n",
        "        return torch.matmul(weights, V)\n",
        "\n",
        "class MultiHeadAttention(nn.Module):\n",
        "    \"\"\"Multi-head attention with output projection, residual connection and LayerNorm.\"\"\"\n",
        "    def __init__(self):\n",
        "        super(MultiHeadAttention, self).__init__()\n",
        "        self.W_Q = nn.Linear(d_model, d_k * n_heads)\n",
        "        self.W_K = nn.Linear(d_model, d_k * n_heads)\n",
        "        self.W_V = nn.Linear(d_model, d_v * n_heads)\n",
        "        # The output projection and the LayerNorm must be attributes of the module. Building\n",
        "        # them inside forward() gave fresh random weights on every call that were never\n",
        "        # registered with the model, so the optimizer could not train them.\n",
        "        self.fc = nn.Linear(n_heads * d_v, d_model)\n",
        "        self.norm = nn.LayerNorm(d_model)\n",
        "\n",
        "    def forward(self, Q, K, V, attn_mask):\n",
        "        \"\"\"Attend from Q to K/V.\n",
        "\n",
        "        Q, K, V: [batch_size, seq_len, d_model]; attn_mask: [batch_size, seq_len, seq_len],\n",
        "        True at positions that must not be attended to.\n",
        "        Returns a tensor of shape [batch_size, seq_len, d_model].\n",
        "        \"\"\"\n",
        "        residual, batch_size = Q, Q.size(0)\n",
        "        # (B, S, D) -proj-> (B, S, D) -split-> (B, S, H, W) -trans-> (B, H, S, W)\n",
        "        q_s = self.W_Q(Q).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # q_s: [batch_size, n_heads, seq_len, d_k]\n",
        "        k_s = self.W_K(K).view(batch_size, -1, n_heads, d_k).transpose(1,2)  # k_s: [batch_size, n_heads, seq_len, d_k]\n",
        "        v_s = self.W_V(V).view(batch_size, -1, n_heads, d_v).transpose(1,2)  # v_s: [batch_size, n_heads, seq_len, d_v]\n",
        "\n",
        "        # the same mask is shared by all heads; expand gives a view instead of n_heads copies\n",
        "        attn_mask = attn_mask.unsqueeze(1).expand(-1, n_heads, -1, -1) # attn_mask : [batch_size, n_heads, seq_len, seq_len]\n",
        "\n",
        "        # context: [batch_size, n_heads, seq_len, d_v]\n",
        "        context = ScaledDotProductAttention()(q_s, k_s, v_s, attn_mask)\n",
        "        # merge the heads again: [batch_size, seq_len, n_heads * d_v]\n",
        "        context = context.transpose(1, 2).contiguous().view(batch_size, -1, n_heads * d_v)\n",
        "        output = self.fc(context)\n",
        "        return self.norm(output + residual) # output: [batch_size, seq_len, d_model]\n",
        "\n",
        "class PoswiseFeedForwardNet(nn.Module):\n",
        "    \"\"\"Position-wise feed-forward net: Linear(d_model, d_ff) -> gelu -> Linear(d_ff, d_model).\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.fc1 = nn.Linear(d_model, d_ff)\n",
        "        self.fc2 = nn.Linear(d_ff, d_model)\n",
        "\n",
        "    def forward(self, x):\n",
        "        expanded = self.fc1(x)  # (batch_size, seq_len, d_model) -> (batch_size, seq_len, d_ff)\n",
        "        activated = gelu(expanded)\n",
        "        return self.fc2(activated)  # (batch_size, seq_len, d_ff) -> (batch_size, seq_len, d_model)\n",
        "\n",
        "class EncoderLayer(nn.Module):\n",
        "    \"\"\"One encoder block: multi-head self-attention followed by the position-wise feed-forward net.\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.enc_self_attn = MultiHeadAttention()\n",
        "        self.pos_ffn = PoswiseFeedForwardNet()\n",
        "\n",
        "    def forward(self, enc_inputs, enc_self_attn_mask):\n",
        "        # self-attention: the same tensor is passed as Q, K and V\n",
        "        attended = self.enc_self_attn(enc_inputs, enc_inputs, enc_inputs, enc_self_attn_mask)\n",
        "        return self.pos_ffn(attended)  # [batch_size, seq_len, d_model]\n",
        "\n",
        "class BERT(nn.Module):\n",
        "    \"\"\"Encoder stack with a next-sentence classification head and a masked-LM head.\"\"\"\n",
        "    def __init__(self):\n",
        "        super().__init__()\n",
        "        self.embedding = Embedding()\n",
        "        self.layers = nn.ModuleList([EncoderLayer() for _ in range(n_layers)])\n",
        "        # pooling head applied to the first ([CLS]) position\n",
        "        self.fc = nn.Sequential(\n",
        "            nn.Linear(d_model, d_model),\n",
        "            nn.Dropout(0.5),\n",
        "            nn.Tanh(),\n",
        "        )\n",
        "        self.classifier = nn.Linear(d_model, 2)\n",
        "        # transform applied to the hidden states gathered at the masked positions\n",
        "        self.linear = nn.Linear(d_model, d_model)\n",
        "        self.activ2 = gelu\n",
        "        # fc2 shares its weight matrix with the token embedding\n",
        "        self.fc2 = nn.Linear(d_model, vocab_size, bias=False)\n",
        "        self.fc2.weight = self.embedding.tok_embed.weight\n",
        "\n",
        "    def forward(self, input_ids, segment_ids, masked_pos):\n",
        "        hidden = self.embedding(input_ids, segment_ids)  # [batch_size, seq_len, d_model]\n",
        "        pad_mask = get_attn_pad_mask(input_ids, input_ids)  # [batch_size, maxlen, maxlen]\n",
        "        for encoder in self.layers:\n",
        "            hidden = encoder(hidden, pad_mask)  # [batch_size, max_len, d_model]\n",
        "\n",
        "        # isNext is decided from the first token (CLS)\n",
        "        pooled = self.fc(hidden[:, 0])  # [batch_size, d_model]\n",
        "        logits_clsf = self.classifier(pooled)  # [batch_size, 2]\n",
        "\n",
        "        # gather the hidden states at the masked positions\n",
        "        gather_index = masked_pos.unsqueeze(-1).expand(-1, -1, d_model)  # [batch_size, max_pred, d_model]\n",
        "        masked_hidden = torch.gather(hidden, 1, gather_index)  # [batch_size, max_pred, d_model]\n",
        "        masked_hidden = self.activ2(self.linear(masked_hidden))  # [batch_size, max_pred, d_model]\n",
        "        logits_lm = self.fc2(masked_hidden)  # [batch_size, max_pred, vocab_size]\n",
        "        return logits_lm, logits_clsf\n",
        "# Model, loss and optimizer\n",
        "model = BERT()\n",
        "criterion = nn.CrossEntropyLoss()  # used for both the masked-LM and the isNext loss\n",
        "optimizer = optim.Adadelta(params=model.parameters(), lr=0.001)"
      ],
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "ShYHlLr-wA_Q",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 101
        },
        "outputId": "304b6e61-d3ea-4659-e1d3-8beca0976876"
      },
      "source": [
        "# Train for 50 epochs; the loss is printed on every 10th epoch.\n",
        "for epoch in range(50):\n",
        "    for input_ids, segment_ids, masked_tokens, masked_pos, isNext in loader:\n",
        "        logits_lm, logits_clsf = model(input_ids, segment_ids, masked_pos)\n",
        "        # masked-LM loss over all max_pred slots (the zero-padded slots are included)\n",
        "        flat_logits = logits_lm.view(-1, vocab_size)\n",
        "        flat_targets = masked_tokens.view(-1)\n",
        "        loss_lm = criterion(flat_logits, flat_targets)\n",
        "        loss_lm = loss_lm.float().mean()\n",
        "        # sentence-pair (isNext) classification loss\n",
        "        loss_clsf = criterion(logits_clsf, isNext)\n",
        "        loss = loss_lm + loss_clsf\n",
        "        report_now = (epoch + 1) % 10 == 0\n",
        "        if report_now:\n",
        "            print('Epoch:', '%04d' % (epoch + 1), 'loss =', '{:.6f}'.format(loss))\n",
        "        optimizer.zero_grad()\n",
        "        loss.backward()\n",
        "        optimizer.step()"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Epoch: 0010 loss = 1.458610\n",
            "Epoch: 0020 loss = 0.919955\n",
            "Epoch: 0030 loss = 0.878155\n",
            "Epoch: 0040 loss = 0.747630\n",
            "Epoch: 0050 loss = 0.724138\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "VMY0ypt8wC9H",
        "colab_type": "code",
        "colab": {
          "base_uri": "https://localhost:8080/",
          "height": 269
        },
        "outputId": "2e4de07c-c757-473c-ca53-5bc983503feb"
      },
      "source": [
        "# Predict the masked tokens and isNext for one sample of the batch\n",
        "input_ids, segment_ids, masked_tokens, masked_pos, isNext = batch[1]\n",
        "print(text)\n",
        "print('================================')\n",
        "print([idx2word[w] for w in input_ids if idx2word[w] != '[PAD]'])\n",
        "\n",
        "# Run inference in evaluation mode: in training mode the Dropout(0.5) of the pooling head\n",
        "# stays active and randomly perturbs the isNext prediction. no_grad() skips autograd bookkeeping.\n",
        "was_training = model.training\n",
        "model.eval()\n",
        "with torch.no_grad():\n",
        "    logits_lm, logits_clsf = model(torch.LongTensor([input_ids]),\n",
        "                     torch.LongTensor([segment_ids]), torch.LongTensor([masked_pos]))\n",
        "model.train(was_training)  # restore the previous mode in case training is resumed afterwards\n",
        "\n",
        "# index of the highest-scoring vocabulary entry for every prediction slot of the sample\n",
        "logits_lm = logits_lm.max(2)[1][0].numpy()\n",
        "print('masked tokens list : ',[pos for pos in masked_tokens if pos != 0])\n",
        "print('predict masked tokens list : ',[pos for pos in logits_lm if pos != 0])\n",
        "\n",
        "# index of the highest-scoring class (1 = IsNext, 0 = NotNext)\n",
        "logits_clsf = logits_clsf.max(1)[1].numpy()[0]\n",
        "print('isNext : ', True if isNext else False)\n",
        "print('predict isNext : ',True if logits_clsf else False)"
      ],
      "execution_count": null,
      "outputs": [
        {
          "output_type": "stream",
          "text": [
            "Hello, how are you? I am Romeo.\n",
            "Hello, Romeo My name is Juliet. Nice to meet you.\n",
            "Nice meet you too. How are you today?\n",
            "Great. My baseball team won the competition.\n",
            "Oh Congratulations, Juliet\n",
            "Thank you Romeo\n",
            "Where are you going today?\n",
            "I am going shopping. What about you?\n",
            "I am going to visit my grandmother. she is not very well\n",
            "================================\n",
            "['[CLS]', 'thank', 'you', 'romeo', '[SEP]', 'hello', 'romeo', '[MASK]', '[MASK]', 'is', 'juliet', 'nice', 'to', 'meet', 'you', '[SEP]']\n",
            "masked tokens list :  [34, 27]\n",
            "predict masked tokens list :  [34, 27]\n",
            "isNext :  False\n",
            "predict isNext :  False\n"
          ],
          "name": "stdout"
        }
      ]
    },
    {
      "cell_type": "code",
      "metadata": {
        "id": "cmlQcIJzUYVI",
        "colab_type": "code",
        "colab": {}
      },
      "source": [
        ""
      ],
      "execution_count": null,
      "outputs": []
    }
  ]
}